import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# 定义简单模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

model = SimpleModel()

# 对模型的参数进行剪枝
prune.l1_unstructured(model.fc, name="weight", amount=0.9)

# 获取剪枝掩码
mask = model.fc.weight_mask  # 这个掩码会与 `weight` 参数形状相同
print("Pruning Mask:", mask)
print("Masked Weight:", model.fc.weight)